#!/usr/bin/env python3
"""
Supplementary Code: Warming event analysis for iron-based ionic liquids on Mars

Reproduces the stochastic atmospheric evolution model of Wordsworth et al.
(2021, Nat. Geosci. 14, 127-132) and extends the warming event analysis of
Turner & Kite (2026, GRL) to temperature thresholds below 243 K.

Requires:
    - Python 3.8+
    - numpy, scipy, numba, matplotlib
    - PCM_data.mat from https://github.com/wordsworthgroup/mars_redox_2021
      (read from PCM_PATH in the __main__ block; default evol_model/PCM_data.mat)

Usage:
    python warming_events_SI.py

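Output:
    - Figure_SI_warming_events.pdf (plus a 300-dpi copy, Figure_SI_warming_events.png)
    - warming_events_data.csv (raw scatter data for each threshold)
    - warming_events_summary.csv (median and P10/P90 statistics, computed over
      realizations with at least one warm event)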
"""

import numpy as np
from scipy.io import loadmat
from scipy.interpolate import RegularGridInterpolator
from scipy.optimize import brentq
from numba import njit
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import csv
import os
import time as timer


# ============================================================================
# 1. CLIMATE GRID
# ============================================================================

def load_pcm_data(path='PCM_data.mat'):
    """Read the PCM-LBL surface-temperature lookup table (7x8x5) from a .mat file.

    Every variable in the file is a MATLAB struct whose ``data`` field holds
    the values. The three axis vectors are flattened to 1-D float64 arrays;
    the temperature table keeps its stored shape.

    Returns
    -------
    Fsol, log_ps, fH2, Ts
        Solar-flux axis, log10 of the surface-pressure axis, H2-fraction
        axis, and the temperature table (axis order as consumed by
        ``build_fine_grid``: [Fsol, ps, fH2]).
    """
    mat = loadmat(path)

    def _payload(key):
        # loadmat hands struct fields back inside a (1, 1) object array.
        return mat[key]['data'][0, 0]

    f_sol, ps, f_h2 = [_payload(key).flatten().astype(np.float64)
                       for key in ('Fsol', 'ps', 'fH2')]
    table = _payload('Ts').astype(np.float64)
    return f_sol, np.log10(ps), f_h2, table


def build_fine_grid(Fsol, log_ps, fH2, Ts, nF=50, nP=60, nH=40):
    """Resample the coarse climate table onto a dense, evenly spaced grid.

    A cubic ``RegularGridInterpolator`` built on the coarse axes is evaluated
    at every node of an nF x nP x nH grid spanning the same ranges. With
    ``fill_value=None``, any point falling outside the coarse box is
    extrapolated rather than set to NaN. The uniform output is what the fast
    ``trilinear`` lookup requires. The original note targets <0.2 K lookup
    error at the default sizes.

    Returns
    -------
    Ff, Lf, Hf, Ts_fine
        The three fine axes and the resampled table of shape (nF, nP, nH).
    """
    spline = RegularGridInterpolator(
        (Fsol, log_ps, fH2), Ts, method='cubic',
        bounds_error=False, fill_value=None)
    fine_axes = [np.linspace(axis[0], axis[-1], count)
                 for axis, count in ((Fsol, nF), (log_ps, nP), (fH2, nH))]
    mesh = np.meshgrid(*fine_axes, indexing='ij')
    nodes = np.column_stack([coord.ravel() for coord in mesh])
    Ts_fine = spline(nodes).reshape(nF, nP, nH)
    return fine_axes[0], fine_axes[1], fine_axes[2], Ts_fine


# ============================================================================
# 2. NUMBA-ACCELERATED CORE
# ============================================================================

@njit
def trilinear(F, lp, fh, Fsol, lps, fH2, Ts):
    """Trilinearly interpolate ``Ts`` at (F, lp, fh) on a uniformly spaced grid.

    The cell index on each axis is derived arithmetically from the first and
    last node, so ``Fsol``, ``lps`` and ``fH2`` are assumed ascending and
    evenly spaced (e.g. ``np.linspace`` output, as from ``build_fine_grid``).
    Queries outside the grid are clamped to the nearest edge; there is no
    extrapolation.

    Parameters
    ----------
    F, lp, fh : float
        Query point: solar flux, log10 surface pressure, H2 fraction.
    Fsol, lps, fH2 : 1-D arrays
        Grid nodes along each axis (at least two each).
    Ts : 3-D array, shape (len(Fsol), len(lps), len(fH2))
        Values at the grid nodes.

    Returns
    -------
    float
        The interpolated value.
    """
    # Clamp the query into the grid box.
    f_q = max(Fsol[0], min(F, Fsol[-1]))
    p_q = max(lps[0], min(lp, lps[-1]))
    h_q = max(fH2[0], min(fh, fH2[-1]))

    # Lower-corner cell index from the uniform spacing, kept in [0, n - 2]
    # so that index + 1 is always a valid node (the upper edge maps into
    # the last cell with weight 1).
    last_f = len(Fsol) - 1
    last_p = len(lps) - 1
    last_h = len(fH2) - 1
    i = max(0, min(int((f_q - Fsol[0]) / (Fsol[-1] - Fsol[0]) * last_f), last_f - 1))
    j = max(0, min(int((p_q - lps[0]) / (lps[-1] - lps[0]) * last_p), last_p - 1))
    k = max(0, min(int((h_q - fH2[0]) / (fH2[-1] - fH2[0]) * last_h), last_h - 1))

    # Fractional position inside the cell along each axis.
    u = (f_q - Fsol[i]) / (Fsol[i + 1] - Fsol[i])
    v = (p_q - lps[j]) / (lps[j + 1] - lps[j])
    w = (h_q - fH2[k]) / (fH2[k + 1] - fH2[k])

    # Collapse the flux axis on the four cell edges, then pressure, then fH2.
    e00 = Ts[i, j, k] * (1 - u) + Ts[i + 1, j, k] * u
    e01 = Ts[i, j, k + 1] * (1 - u) + Ts[i + 1, j, k + 1] * u
    e10 = Ts[i, j + 1, k] * (1 - u) + Ts[i + 1, j + 1, k] * u
    e11 = Ts[i, j + 1, k + 1] * (1 - u) + Ts[i + 1, j + 1, k + 1] * u
    low_h = e00 * (1 - v) + e10 * v
    high_h = e01 * (1 - v) + e11 * v
    return low_h * (1 - w) + high_h * w


@njit
def run_single(nt, dt, T_total, g_exp, x0_0, x1_0, td_supply, rands,
               Fsol_arr, uCO2_arr, Fsol_fg, lps_fg, fH2_fg, Ts_fg,
               muCO2, muH2, g_mars, Area, NA, My,
               H_CO2_val, b_H2CO2, conv, weather_rate,
               alpha_DH, H_loss_now, td_DH, gamma,
               N_CO2_factor, fH2_max):
    """Single realization of the Wordsworth et al. (2021) stochastic model.

    Steps a scalar redox inventory N (starting at 0) forward over ``nt``
    steps of size ``dt``, at times tt = (it + 1) * dt. Each step adds a
    stochastic supply term, escape terms and a weathering term, and then
    records the surface temperature from the climate grid. N < 0 is treated
    as H2 in the atmosphere (it sets fH2). N > 0 is removed by weathering.

    Parameters
    ----------
    nt, dt, T_total :
        Step count, step size and the run's final time. Same time unit as
        td_supply / td_DH; the __main__ driver reports dt-based spans in Myr.
    g_exp, x0_0, x1_0, td_supply :
        Power-law exponent and initial bounds of the supply distribution
        (see compute_supply); the bounds decay as exp(-tt / td_supply).
    rands : 1-D array, length >= nt
        Uniform draws, one per step, for the supply inverse CDF.
    Fsol_arr, uCO2_arr : 1-D arrays, length >= nt
        Per-step solar flux and CO2 column used for the climate lookup.
    Fsol_fg, lps_fg, fH2_fg, Ts_fg :
        Uniform climate grid passed straight to ``trilinear``.
    NA, My :
        Accepted but not used inside this function.
    fH2_max :
        Upper cap on fH2 before the lookup (the driver passes the last
        value of the PCM fH2 axis).

    Returns
    -------
    Tsurf : ndarray, shape (nt,)
        Surface temperature time series [K].
    """
    Tsurf = np.zeros(nt)
    N = 0.0  # net redox inventory; N < 0 => H2-bearing (reducing) atmosphere
    gp1 = g_exp + 1.0

    for it in range(nt):
        tt = (it + 1) * dt

        # Supply: stochastic reducing gas input
        # Inverse-CDF sample of p(x) ~ x**g_exp truncated to [x0, x1]:
        # x = [x0**gp1 + u * (x1**gp1 - x0**gp1)]**(1 / gp1). It enters with
        # a negative sign, i.e. supply pushes N toward reducing.
        S = np.exp(-tt / td_supply)
        x0 = x0_0 * S
        x1 = x1_0 * S
        val = (x1**gp1 - x0**gp1) * rands[it] + x0**gp1
        # Non-positive base would make the fractional power invalid; skip.
        dNdt_s = -(val ** (1.0 / gp1)) if val > 0 else 0.0

        uCO2 = uCO2_arr[it]
        # CO2 inventory; setup_model's N_CO2_factor is Area / muCO2.
        N_CO2 = uCO2 * N_CO2_factor

        # Escape
        dNdt_e = 0.0
        if N < 0:
            # Exponential relaxation of N toward 0 over one step, with an
            # e-folding time tau_diff proportional to the CO2 inventory
            # (presumably diffusion-limited H2 escape, per the names).
            tau_diff = H_CO2_val * N_CO2 / (conv * b_H2CO2)
            if tau_diff > 0:
                dNdt_e = N * (np.exp(-dt / tau_diff) - 1) / dt
                # Defensive clamp: since 1 - exp(-dt/tau) < 1 this step
                # already cannot overshoot N = 0.
                if dNdt_e > abs(N / dt):
                    dNdt_e = abs(N / dt)
        # Flux term held constant before tt = 1e3 and decaying as
        # (t / 4.5e3)**(-1.2 * gamma) afterwards; PhiO < 0, so it lowers
        # dNdt_e. (Presumably O escape per the name -- confirm against the
        # reference model.)
        t_esc = max(tt, 1e3)
        PhiO = -4.3e25 * (t_esc / 4.5e3)**(-1.2 * gamma) / Area
        dNdt_e += 2 * conv * PhiO
        # Equals alpha_DH * H_loss_now at tt == T_total and grows by e per
        # td_DH going back in time.
        dNdt_e += alpha_DH * H_loss_now * np.exp(-(tt - T_total) / td_DH)

        # Weathering
        # Constant removal of an oxidizing surplus (N > 0), clamped so a
        # single step cannot take N below zero.
        dNdt_w = 0.0
        if N > 0:
            dNdt_w = -weather_rate
            if abs(dNdt_w) > N / dt:
                dNdt_w = -(N / dt)

        # Explicit (forward Euler) update of the inventory.
        N = N + (dNdt_w + dNdt_e + dNdt_s) * dt

        # Surface temperature
        Fsol = Fsol_arr[it]
        fH2 = 0.0
        if N < 0:
            # H2 mole fraction with n_H2 = |N| / 2 (factor 2 as in the
            # reference code; presumably N counts H atoms / electron
            # equivalents -- confirm).
            absN2 = abs(N / 2)
            fH2 = absN2 / (absN2 + N_CO2)
        fH2_lim = min(fH2, fH2_max)  # keep the lookup inside the fH2 axis
        fCO2 = 1.0 - fH2_lim
        if fCO2 > 1e-10:
            # Total surface pressure from the CO2 column: with uCO2 a CO2
            # column mass, ps = uCO2 * g * (mu_avg / mu_CO2) / fCO2.
            muavg = fCO2 * muCO2 + fH2_lim * muH2
            ps = (uCO2 * g_mars) * (muavg / muCO2) / fCO2
        else:
            ps = uCO2 * g_mars
        # NOTE(review): log10(ps) is in the units of uCO2 * g_mars (Pa for
        # compute_co2's output); confirm this matches the PCM ps axis.
        # trilinear clamps out-of-range queries to the grid edges. 200.0 K is
        # a placeholder for non-positive pressure.
        Tsurf[it] = trilinear(Fsol, np.log10(max(ps, 1e-10)), fH2_lim,
                               Fsol_fg, lps_fg, fH2_fg, Ts_fg) if ps > 0 else 200.0

    return Tsurf


# ============================================================================
# 3. MODEL SETUP (matching model_setup.m)
# ============================================================================

def setup_model():
    """Build the parameter dictionary for the evolution model (cf. model_setup.m).

    Covers the time grid, physical constants (SI), Mars bulk properties, the
    escape / weathering / supply parameters, and the solar flux at Mars on
    every time step. The time grid runs dt, 2*dt, ..., T with T = 4499 and
    nt = 44990 (the driver treats these times as Myr).
    """
    # --- time grid ---
    y2s = 86400 * 365.25            # seconds per year
    My = 1e6 * y2s                  # seconds per Myr
    T = 4.5e3 - 1
    nt = 44990
    dt = T / nt
    t_a = np.arange(1, nt + 1) * dt  # step end times: dt .. T

    # --- physical constants ---
    G = 6.67408e-11
    muCO2 = 44.01e-3
    muH2 = 2.016e-3
    kB = 1.3806503e-23
    Rstar = 8.314463
    NA = Rstar / kB                 # Avogadro's number as R / k_B
    m_u = 1.660539e-27

    # --- Mars ---
    M_mars = 0.64171e24
    r_mars = 3389.5e3
    g_mars = G * M_mars / r_mars**2
    Area = 4 * np.pi * r_mars**2
    bar = 1e5

    # --- escape coefficients ---
    N_CO2_factor = Area / muCO2
    H_loss_now = My * 1.0e27 / NA
    Thomo = 500.0
    H_CO2_val = kB * Thomo / (m_u * muCO2 * 1e3 * g_mars)
    b_H2CO2 = 1e2 * 22.3 * Thomo**0.75 / 1e-16
    conv = Area * My / NA

    # --- faint young Sun: flux is Fmars0 / 1.4 at t = 0, Fmars0 at t = 4500 ---
    F0 = 1366.0
    d_mars = 1.524
    Fmars0 = F0 / d_mars**2
    Fsol_arr = Fmars0 / (1 + 0.4 * (1 - t_a / 4500.0))

    return {
        'y2s': y2s, 'My': My, 'T': T, 'nt': nt, 'dt': dt, 't_a': t_a,
        'G': G, 'muCO2': muCO2, 'muH2': muH2, 'kB': kB, 'Rstar': Rstar,
        'NA': NA, 'm_u': m_u,
        'M_mars': M_mars, 'r_mars': r_mars, 'g_mars': g_mars,
        'Area': Area, 'bar': bar,
        'N_CO2_factor': N_CO2_factor, 'H_loss_now': H_loss_now,
        'Thomo': Thomo, 'H_CO2_val': H_CO2_val, 'b_H2CO2': b_H2CO2,
        'conv': conv,
        'alpha_DH': 0.48, 'td_DH': 1.1e3, 'gamma': 1.7,
        'weather_rate': 1.58e3, 'g_exp': -1.95, 'td_supply': 1.0e3,
        'Fsol_arr': Fsol_arr,
    }


def compute_co2(p, pCO2_3p5Gya=1.5):
    """Return the CO2 column history on the model time grid ``p['t_a']``.

    The CO2 pressure follows a tanh decline centred at t = 800 with rate
    1.5e-3 per time unit. It is normalised so that the tanh part equals
    ``pCO2_3p5Gya`` (in bar) at t = 1000, and a constant 0.00301 bar floor is
    added. Both parts are converted from pressure to column via
    ``p['bar'] / p['g_mars']``.
    """
    rate, t_mid = 1.5e-3, 0.8e3
    amplitude = pCO2_3p5Gya / (1 - np.tanh(rate * (1e3 - t_mid)))
    bar, grav = p['bar'], p['g_mars']
    decline = 1 - np.tanh(rate * (p['t_a'] - t_mid))
    return bar * amplitude * decline / grav + 0.00301 * bar / grav


def compute_supply(p, beta, Notot):
    """Return the initial bounds (x0_0, x1_0) of the stochastic supply distribution.

    ``run_single`` draws each step's supply from a power law p(x) ~ x**g_exp
    truncated to [x0, x1], with both bounds scaled by exp(-t / td_supply).
    The initial bounds are chosen so that:

    * the mean rate mu0 integrates to ``Notot`` over the run,
      i.e. mu0 * td * (1 - exp(-T / td)) = Notot;
    * the upper cutoff is ``beta`` times that mean (x1_0 = beta * mu0);
    * the lower cutoff x0_0 makes the truncated power-law mean,
      (g+1)/(g+2) * (x1**(g+2) - x0**(g+2)) / (x1**(g+1) - x0**(g+1)),
      equal to mu0. It is found by root-finding in log10(x0).

    Parameters
    ----------
    p : dict
        Model parameters; reads 'g_exp', 'td_supply' and 'T'.
    beta : float
        Ratio of the upper cutoff to the mean supply rate.
    Notot : float
        Target integrated supply over the run.

    Returns
    -------
    (x0_0, x1_0) : tuple of float

    Warns
    -----
    RuntimeWarning
        If no root is bracketed (e.g. beta <= 1, where even the narrowest
        allowed interval has a mean below mu0). x0_0 then falls back to
        1e-10 * mu0, and the realized mean supply will not equal mu0.
    """
    g = p['g_exp']
    td = p['td_supply']
    gg = (g + 1) / (g + 2)
    mu0 = Notot / ((1 - np.exp(-p['T'] / td)) * td)
    x1_0 = beta * mu0

    def residual(log_x0):
        # log10(mean of truncated power law) - log10(target mean).
        x0 = 10**log_x0
        num = x1_0**(g + 2) - x0**(g + 2)
        den = x1_0**(g + 1) - x0**(g + 1)
        if abs(den) < 1e-300:
            # Degenerate interval (x0 ~ x1): report "mean too large".
            return 10.0
        mu_calc = gg * num / den
        return np.log10(max(mu_calc, 1e-300)) - np.log10(mu0)

    try:
        # Search x0 between 1e-20 and x1_0 / 10**0.1 (log-space bracket).
        x0_0 = 10**brentq(residual, -20, np.log10(x1_0) - 0.1)
    except ValueError:
        import warnings
        warnings.warn(
            f"compute_supply: no lower cutoff reproduces the mean supply "
            f"{mu0:.3e} for beta={beta}; falling back to x0 = 1e-10 * mu0, "
            f"so the realized mean will differ from the target.",
            RuntimeWarning, stacklevel=2)
        x0_0 = 1e-10 * mu0
    return x0_0, x1_0


# ============================================================================
# 4. EVENT COUNTING
# ============================================================================

def count_events_timespan(Tsurf, dt, T_threshold):
    """Count warm events and the time-span from the first to the last warm step.

    A warm step is one with Tsurf strictly above ``T_threshold``. An event is
    a maximal run of consecutive warm steps. The time-span is the index
    distance between the first and last warm step times ``dt``, so a single
    isolated warm step has span 0. Follows the Turner & Kite (2026) definition.

    Returns
    -------
    (n_events, span) : (int, float)
        (0, 0.0) when no step is warm.
    """
    warm_idx = np.where(Tsurf > T_threshold)[0]
    if warm_idx.size == 0:
        return 0, 0.0
    # Every gap between consecutive warm indices starts a new event.
    n_events = 1 + np.sum(np.diff(warm_idx) > 1)
    first, last = warm_idx[0], warm_idx[-1]
    return n_events, (last - first) * dt


# ============================================================================
# 5. MAIN
# ============================================================================

if __name__ == '__main__':

    # --- Configuration ---
    PCM_PATH = 'evol_model/PCM_data.mat'  # adjust path as needed
    T_THRESHOLDS = np.array([222, 225, 228, 230, 233, 235, 238, 240,
                              243, 248, 253, 258, 263, 268, 273])
    BETA_VALUES = [100, 300, 1000, 3000, 10000]
    N_REAL = 50     # realizations per beta value
    SEED = 42       # set after the JIT warm-up, so warm-up draws don't matter

    print("Warming event analysis — Supplementary Material")
    print("=" * 60)

    # Setup
    p = setup_model()
    uCO2 = compute_co2(p)
    Fsol_a, log_ps_a, fH2_a, Ts_a = load_pcm_data(PCM_PATH)
    Fsol_fg, lps_fg, fH2_fg, Ts_fg = build_fine_grid(Fsol_a, log_ps_a, fH2_a, Ts_a)
    Notot = 1.5e4 * p['My'] * p['T']

    # Trailing run_single arguments (climate grid, constants and the fH2 cap)
    # are identical for every call, so assemble them once.
    shared_args = (Fsol_fg, lps_fg, fH2_fg, Ts_fg,
                   p['muCO2'], p['muH2'], p['g_mars'], p['Area'],
                   p['NA'], p['My'], p['H_CO2_val'], p['b_H2CO2'],
                   p['conv'], p['weather_rate'], p['alpha_DH'],
                   p['H_loss_now'], p['td_DH'], p['gamma'],
                   p['N_CO2_factor'], fH2_a[-1])

    # Warm up numba JIT on a short run so compilation is excluded from timing.
    x0_t, x1_t = compute_supply(p, 1000, Notot)
    run_single(100, p['dt'], p['T'], p['g_exp'], x0_t, x1_t,
               p['td_supply'], np.random.rand(100),
               p['Fsol_arr'][:100], uCO2[:100], *shared_args)

    # Run ensemble
    np.random.seed(SEED)
    all_n, all_d, all_T, all_beta, all_real = [], [], [], [], []

    t0 = timer.time()
    for beta in BETA_VALUES:
        x0_0, x1_0 = compute_supply(p, beta, Notot)
        for ir in range(N_REAL):
            rands = np.random.rand(p['nt'])
            Tsurf = run_single(
                p['nt'], p['dt'], p['T'], p['g_exp'], x0_0, x1_0,
                p['td_supply'], rands, p['Fsol_arr'], uCO2, *shared_args)
            # One record per (beta, realization, threshold).
            for T_thresh in T_THRESHOLDS:
                n_ev, t_span = count_events_timespan(Tsurf, p['dt'], T_thresh)
                all_n.append(n_ev)
                all_d.append(t_span)
                all_T.append(T_thresh)
                all_beta.append(beta)
                all_real.append(ir)

    elapsed = timer.time() - t0
    total = len(BETA_VALUES) * N_REAL
    print(f"Completed {total} runs in {elapsed:.1f}s ({elapsed/total:.3f}s per run)")

    all_n = np.array(all_n)
    all_d = np.array(all_d)
    all_T = np.array(all_T)
    all_beta = np.array(all_beta)
    all_real = np.array(all_real)

    # --- Save raw data ---
    with open('warming_events_data.csv', 'w', newline='') as f:
        w = csv.writer(f)
        w.writerow(['T_threshold_K', 'beta', 'realization',
                    'N_warm_events', 'Timespan_Myr'])
        w.writerows([t_k, b_k, r_k, n_k, f"{d_k:.1f}"]
                    for t_k, b_k, r_k, n_k, d_k
                    in zip(all_T, all_beta, all_real, all_n, all_d))
    print("Saved: warming_events_data.csv")

    # --- Save summary ---
    # Statistics are conditional on at least one warm event: realizations
    # that never exceed T are excluded, and T gets no row if none do.
    with open('warming_events_summary.csv', 'w', newline='') as f:
        w = csv.writer(f)
        w.writerow(['T_threshold_K', 'N_events_median', 'N_events_P10',
                    'N_events_P90', 'Timespan_Myr_median',
                    'Timespan_Myr_P10', 'Timespan_Myr_P90'])
        for T in sorted(T_THRESHOLDS, reverse=True):
            mask = (all_T == T) & (all_n > 0)
            if not mask.any():
                print(f"  note: no realization exceeds {T} K; "
                      f"omitted from summary")
                continue
            ns, ds = all_n[mask], all_d[mask]
            w.writerow([T, f"{np.median(ns):.0f}",
                        f"{np.percentile(ns,10):.0f}",
                        f"{np.percentile(ns,90):.0f}",
                        f"{np.median(ds):.0f}",
                        f"{np.percentile(ds,10):.0f}",
                        f"{np.percentile(ds,90):.0f}"])
    print("Saved: warming_events_summary.csv")

    # --- Figure ---
    fig, ax = plt.subplots(1, 1, figsize=(5.5, 4.5))
    cmap = plt.cm.RdYlBu_r
    norm = mcolors.Normalize(vmin=220, vmax=275)

    mask = all_n > 0
    plot_n = all_n[mask].astype(float)
    plot_d = all_d[mask].astype(float)
    plot_T = all_T[mask].astype(float)

    # Draw warmest thresholds first so colder (more numerous) points sit on top.
    sort_idx = np.argsort(plot_T)[::-1]
    sc = ax.scatter(plot_n[sort_idx], plot_d[sort_idx], c=plot_T[sort_idx],
                    cmap=cmap, norm=norm, s=12, alpha=0.4, edgecolors='none',
                    zorder=3, rasterized=True)

    cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02, aspect=25)
    cbar.set_label('$T$ threshold (K)', fontsize=9)
    cbar.set_ticks([222, 230, 240, 250, 260, 273])
    cbar.ax.tick_params(labelsize=7.5)

    ax.annotate(
        '[EMIM][FeBr₄] + CO₂\n$T_g$ = 201 K\nliquid at baseline (222 K)',
        xy=(0.96, 0.96), xycoords='axes fraction',
        fontsize=8.5, fontweight='bold', color='#145a32',
        ha='right', va='top',
        bbox=dict(boxstyle='round,pad=0.35', fc='#e8f5e8', ec='#27ae60',
                  lw=1.0, alpha=0.95), zorder=15)

    # Fixed axis limits; points beyond them are not shown.
    ax.set_xlim(0, 180)
    ax.set_ylim(0, 1700)
    ax.set_xlabel('Number of warm events', fontsize=10)
    ax.set_ylabel('Total time-span of warm events (Myr)', fontsize=10)
    ax.tick_params(labelsize=8, which='both', direction='in',
                   top=True, right=True)

    plt.tight_layout(pad=0.5)
    for fig_name, fig_dpi in (('Figure_SI_warming_events.pdf', 600),
                              ('Figure_SI_warming_events.png', 300)):
        fig.savefig(fig_name, dpi=fig_dpi, bbox_inches='tight')
        print(f"Saved: {fig_name}")
    plt.close(fig)
    print("Done.")
